{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "learningtopaint.ipynb",
      "version": "0.3.2",
      "provenance": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/hzwer/LearningToPaint/blob/master/LearningToPaint.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "metadata": {
        "id": "TFN3oT1Hkjfs",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "!git clone https://github.com/hzwer/LearningToPaint.git"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "Dp7N29tGkwQs",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "%cd LearningToPaint/"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "qTbhmFyawzhO",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
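        "## Testing\n",
        "\n",
        "Download the pretrained renderer and actor weights plus a sample image, paint it with `baseline/test.py`, and assemble the generated frames into `video.mp4`."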
      ]
    },
    {
      "metadata": {
        "id": "z0wTTzOEbvps",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "!wget \"https://drive.google.com/uc?export=download&id=1-7dVdjCIZIxh8hHJnGTK-RA1-jL1tor4\" -O renderer.pkl"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "Pfd53Hw2cfaY",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "!wget \"https://drive.google.com/uc?export=download&id=1a3vpKgjCVXHON4P7wodqhCgCMPgg1KeR\" -O actor.pkl"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "QZpb3_3QiMZw",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "!wget -U NoSuchBrowser/1.0 -O image/test.png https://raw.githubusercontent.com/hzwer/LearningToPaint/master/image/Trump.png"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "brX4ZlQoc9ss",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "!python3 baseline/test.py --max_step=80 --actor=actor.pkl --renderer=renderer.pkl --img=image/test.png --divide=5"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "tLM4U6F0_yjV",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "!ffmpeg -y -r 30 -f image2 -i output/generated%d.png -s 512x512 -c:v libx264 -pix_fmt yuv420p video.mp4"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "ekY7HcBeh8zl",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "from IPython.display import display, Image\n",
        "import moviepy.editor as mpy\n",
        "display(mpy.ipython_display('video.mp4', height=256, max_duration=100.))\n",
        "display(Image('output/generated399.png'))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "d2mAkgRjwwuf",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
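        "## Training\n",
        "\n",
        "Download and unpack the CelebA images into `data/`, train the renderer, then train the painting agent with `baseline/train.py`."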
      ]
    },
    {
      "metadata": {
        "id": "_-p0NhqyTqO_",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "!mkdir -p data"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "XXAV9RwkTwKh",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "%cd data"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "colab_type": "code",
        "id": "IzZUVjdrET2G",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "!gdown \"https://drive.google.com/uc?id=0B7EVK8r0v71pZjFTYXZWM3FlRnM\""
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "colab_type": "code",
        "id": "zgguAW3eETVd",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "!unzip -q -o img_align_celeba.zip"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "XBH--DY-sK8V",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "!rm img_align_celeba.zip"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "u6mVpjvBvzrb",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "%cd .."
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "-PYJVt8pc6BP",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "!python3 baseline/train_renderer.py"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "VZWjNmD23gKm",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "%pip install tensorboardX"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "ehnzhWn9GG4I",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "%%writefile baseline/env.py\n",
        "import sys\n",
        "import json\n",
        "import torch\n",
        "import numpy as np\n",
        "import argparse\n",
        "import torchvision.transforms as transforms\n",
        "import cv2\n",
        "from DRL.ddpg import decode\n",
        "from utils.util import *\n",
        "from PIL import Image\n",
        "from torchvision import transforms, utils\n",
        "# Use the GPU when one is visible, otherwise fall back to the CPU.\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "# Augmentation pipeline: convert to a PIL image, then flip horizontally at random.\n",
        "aug = transforms.Compose([\n",
        "    transforms.ToPILImage(),\n",
        "    transforms.RandomHorizontalFlip(),\n",
        "])\n",
        "\n",
        "# Side length of the square canvas, and its area in pixels.\n",
        "width = 128\n",
        "convas_area = width * width\n",
        "\n",
        "# Image stores and their counters, filled in by Paint.load_data().\n",
        "img_train, img_test = [], []\n",
        "train_num = test_num = 0\n",
        "\n",
        "class Paint:\n",
        "    \"\"\"Batched painting environment over the CelebA images loaded by `load_data`.\n",
        "\n",
        "    The canvas and target batches are uint8 tensors of shape\n",
        "    (batch_size, 3, width, width) on `device`. The reward of a step is the\n",
        "    drop in mean squared distance to the target, divided by the distance\n",
        "    measured on the blank canvas at reset time.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, batch_size, max_step):\n",
        "        \"\"\"Record the batch size and episode length; no image is loaded here.\"\"\"\n",
        "        self.batch_size = batch_size\n",
        "        self.max_step = max_step\n",
        "        # NOTE: (13) is the plain int 13; the parentheses do not build a tuple.\n",
        "        self.action_space = (13)\n",
        "        self.observation_space = (self.batch_size, width, width, 7)\n",
        "        self.test = False\n",
        "\n",
        "    def load_data(self):\n",
        "        \"\"\"Read 000001.jpg .. 200000.jpg from ./data/img_align_celeba/ into memory.\n",
        "\n",
        "        Each image is resized to (width, width). Loop indices 0..2000 are\n",
        "        appended to `img_test`, all later ones to `img_train`; the module-level\n",
        "        counters `train_num` and `test_num` are updated alongside.\n",
        "\n",
        "        Files that cv2 cannot read are skipped and counted instead of\n",
        "        crashing inside cv2.resize.\n",
        "\n",
        "        Raises:\n",
        "            RuntimeError: if not a single image could be read.\n",
        "        \"\"\"\n",
        "        # CelebA\n",
        "        global train_num, test_num\n",
        "        total = 200000\n",
        "        skipped = 0\n",
        "        for i in range(total):\n",
        "            img_id = '%06d' % (i + 1)\n",
        "            # cv2.imread does not raise on a missing/unreadable file; it returns None.\n",
        "            img = cv2.imread('./data/img_align_celeba/' + img_id + '.jpg', cv2.IMREAD_UNCHANGED)\n",
        "            if img is None:\n",
        "                skipped += 1\n",
        "            else:\n",
        "                img = cv2.resize(img, (width, width))\n",
        "                if i > 2000:\n",
        "                    train_num += 1\n",
        "                    img_train.append(img)\n",
        "                else:\n",
        "                    test_num += 1\n",
        "                    img_test.append(img)\n",
        "            if (i + 1) % 10000 == 0:\n",
        "                print('loaded {} images'.format(i + 1))\n",
        "        if skipped == total:\n",
        "            raise RuntimeError('no images could be read from ./data/img_align_celeba/')\n",
        "        if skipped:\n",
        "            print('skipped {} unreadable images'.format(skipped))\n",
        "        print('finish loading data, {} training images, {} testing images'.format(str(train_num), str(test_num)))\n",
        "\n",
        "    def pre_data(self, id, test):\n",
        "        \"\"\"Return image `id` as a channel-first (3, width, width) array.\n",
        "\n",
        "        Test images come from `img_test` unchanged; training images come from\n",
        "        `img_train` and pass through `aug` (random horizontal flip) first.\n",
        "        \"\"\"\n",
        "        if test:\n",
        "            img = img_test[id]\n",
        "        else:\n",
        "            img = aug(img_train[id])\n",
        "        img = np.asarray(img)\n",
        "        return np.transpose(img, (2, 0, 1))\n",
        "\n",
        "    def reset(self, test=False, begin_num=False):\n",
        "        \"\"\"Start a new episode and return the first observation.\n",
        "\n",
        "        In test mode the targets are the consecutive test images starting at\n",
        "        `begin_num` (wrapping around `test_num`); the default False behaves as 0.\n",
        "        Otherwise every target is drawn uniformly at random from the training set.\n",
        "        \"\"\"\n",
        "        self.test = test\n",
        "        self.imgid = [0] * self.batch_size\n",
        "        self.gt = torch.zeros([self.batch_size, 3, width, width], dtype=torch.uint8).to(device)\n",
        "        for i in range(self.batch_size):\n",
        "            if test:\n",
        "                id = (i + begin_num) % test_num\n",
        "            else:\n",
        "                id = np.random.randint(train_num)\n",
        "            self.imgid[i] = id\n",
        "            self.gt[i] = torch.tensor(self.pre_data(id, test))\n",
        "        # Per-sample mean squared intensity of the target, scaled to [0, 1].\n",
        "        self.tot_reward = ((self.gt.float() / 255) ** 2).mean(1).mean(1).mean(1)\n",
        "        self.stepnum = 0\n",
        "        self.canvas = torch.zeros([self.batch_size, 3, width, width], dtype=torch.uint8).to(device)\n",
        "        # Distance of the blank canvas; cal_reward divides by it.\n",
        "        self.lastdis = self.ini_dis = self.cal_dis()\n",
        "        return self.observation()\n",
        "\n",
        "    def observation(self):\n",
        "        \"\"\"Return canvas, target and step counter concatenated on the channel axis.\n",
        "\n",
        "        The result has shape (batch_size, 7, width, width): 3 canvas channels,\n",
        "        3 target channels and 1 plane filled with the current step number.\n",
        "        \"\"\"\n",
        "        # canvas B * 3 * width * width\n",
        "        # gt B * 3 * width * width\n",
        "        # T B * 1 * width * width\n",
        "        T = torch.ones([self.batch_size, 1, width, width], dtype=torch.uint8) * self.stepnum\n",
        "        return torch.cat((self.canvas, self.gt, T.to(device)), 1) # canvas, img, T\n",
        "\n",
        "    def cal_trans(self, s, t):\n",
        "        \"\"\"Swap dims 0 and 3 of `s`, multiply by `t`, and swap the dims back.\"\"\"\n",
        "        return (s.transpose(0, 3) * t).transpose(0, 3)\n",
        "\n",
        "    def step(self, action):\n",
        "        \"\"\"Apply `action` to the canvas and advance the episode by one step.\n",
        "\n",
        "        Returns:\n",
        "            (observation, reward, done, None), where `done` is a boolean array\n",
        "            of length batch_size that turns True once `max_step` is reached.\n",
        "        \"\"\"\n",
        "        # decode works on a float canvas in [0, 1]; convert back to uint8 afterwards.\n",
        "        self.canvas = (decode(action, self.canvas.float() / 255) * 255).byte()\n",
        "        self.stepnum += 1\n",
        "        ob = self.observation()\n",
        "        done = (self.stepnum == self.max_step)\n",
        "        reward = self.cal_reward()\n",
        "        return ob.detach(), reward, np.array([done] * self.batch_size), None\n",
        "\n",
        "    def cal_dis(self):\n",
        "        \"\"\"Return the per-sample mean squared distance between canvas and target.\"\"\"\n",
        "        return (((self.canvas.float() - self.gt.float()) / 255) ** 2).mean(1).mean(1).mean(1)\n",
        "\n",
        "    def cal_reward(self):\n",
        "        \"\"\"Return the normalised distance improvement since the previous call.\"\"\"\n",
        "        dis = self.cal_dis()\n",
        "        # 1e-8 guards against a zero initial distance.\n",
        "        reward = (self.lastdis - dis) / (self.ini_dis + 1e-8)\n",
        "        self.lastdis = dis\n",
        "        return to_numpy(reward)\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "0kwVmo6yv1w3",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "!python3 baseline/train.py --max_step=200 --debug --batch_size=96"
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}